#!/usr/bin/env python3
# libvirt 'run' programs locally script
# Copyright (C) 2012-2021 Red Hat, Inc.
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; If not, see
# <http://www.gnu.org/licenses/>.

# ----------------------------------------------------------------------
#
# With this script you can run libvirt programs without needing to
# install them first.  You just have to do for example:
#
#   ./run virsh [args ...]
#
# Note that this runs the locally compiled copy of virsh which
# is usually what you want.
#
# You can also run the C programs under valgrind like this:
#
#   ./run valgrind [valgrind opts...] ./program
#
# or under gdb:
#
#   ./run gdb --args ./program
#
# This also works with sudo (eg. if you need root access for libvirt):
#
#   sudo ./run virsh list --all
#
# ----------------------------------------------------------------------

import argparse
import os
import os.path
import random
import shutil
import signal
import subprocess
import sys


# Function to intelligently prepend a path to an environment variable.
# See https://stackoverflow.com/a/9631350
def prepend(env, varname, extradir):
    """Prepend extradir to the colon-separated list stored in env[varname].

    env is any mutable mapping of environment variables (os.environ or a
    plain dict); it is modified in place.

    When the variable is unset *or empty* it is set to just extradir, with
    no ':' separator. A stray leading/trailing ':' would create an empty
    path element, which PATH-like variables treat as the current directory.
    """
    # Look the value up in the mapping we are about to modify, so that the
    # function also behaves correctly when env is not os.environ itself.
    current = env.get(varname)
    if current:
        env[varname] = extradir + ":" + current
    else:
        env[varname] = extradir


# Build directory placeholder; the same value is exported as abs_builddir
# further down.
here = "@abs_builddir@"

parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
parser.add_argument(
    '--selinux',
    action='store_true',
    help='Run in the appropriate selinux context',
)

# Everything the parser does not recognise is the command line to run.
opts, args = parser.parse_known_args()

if not args:
    print("syntax: %s [--selinux] BINARY [ARGS...]" % sys.argv[0], file=sys.stderr)
    sys.exit(1)

prog = args[0]
env = os.environ

# Make libraries, pkg-config files and binaries under the build directory
# take precedence. "PATH" is listed twice on purpose: "src" is prepended
# last and therefore ends up in front of "tools".
for varname, subdir in (
    ("LD_LIBRARY_PATH", "src"),
    ("PKG_CONFIG_PATH", "src"),
    ("PATH", "tools"),
    ("PATH", "src"),
):
    prepend(env, varname, os.path.join(here, subdir))

env.update({
    # Lets 3rd party apps that load libvirt.so from the build tree (e.g.
    # language bindings testing a non-installed libvirt) resolve files
    # against the build/source tree as well.
    "LIBVIRT_DIR_OVERRIDE": "1",
    # Cheap glibc trick to flush out some use-after-free and
    # uninitialized read bugs.
    "MALLOC_PERTURB_": str(random.randint(1, 255)),
    "abs_builddir": "@abs_builddir@",
    "abs_top_builddir": "@abs_top_builddir@",
})

# Names of the modular daemon binaries. Used to recognise a modular daemon
# on the command line, to enumerate the units that conflict with libvirtd,
# and to validate the program name when --selinux is given.
modular_daemons = [
    "virtinterfaced",
    "virtlxcd",
    "virtnetworkd",
    "virtnodedevd",
    "virtnwfilterd",
    "virtproxyd",
    "virtqemud",
    "virtsecretd",
    "virtstoraged",
    "virtvboxd",
    "virtvzd",
    "virtxend",
]


def is_modular_daemon(name):
    """Return True when name is listed in modular_daemons."""
    known_daemons = modular_daemons
    return name in known_daemons


def is_monolithic_daemon(name):
    """Return True when name is exactly the monolithic daemon, "libvirtd"."""
    monolithic = "libvirtd"
    return name == monolithic


def is_systemd_host():
    """Return True when running as uid 0 and /run/systemd/system exists.

    Any non-root user gets False without the filesystem being consulted.
    """
    running_as_root = os.getuid() == 0
    return running_as_root and os.path.exists("/run/systemd/system")


def daemon_units(name):
    """Return the systemd unit names derived from daemon name.

    The result is a list holding the service unit followed by the main,
    read-only and admin socket units, in that order.
    """
    units = []
    for suffix in (".service", ".socket", "-ro.socket", "-admin.socket"):
        units.append(name + suffix)
    return units


def is_unit_active(name):
    """Return True when "systemctl is-active -q name" exits with status 0."""
    command = ["systemctl", "is-active", "-q", name]
    return subprocess.call(command) == 0


def change_unit(name, action):
    """Run "systemctl <action> -q <name>" and return True on exit status 0.

    action is passed to systemctl verbatim; this script uses "stop" and
    "start".
    """
    command = ["systemctl", action, "-q", name]
    return subprocess.call(command) == 0


def chcon(path, user, role, type):
    """Announce and apply an selinux file context to path via chcon.

    user, role and type are handed to chcon's -u, -r and -t options
    respectively. Returns True when chcon exits with status 0.
    """
    message = "Setting file context of {} to u={}, r={}, t={}..."
    print(message.format(path, user, role, type))
    command = ["chcon", "-u", user, "-r", role, "-t", type, path]
    return subprocess.call(command) == 0


def restorecon(path):
    """Announce and run restorecon on path; True when it exits with 0."""
    print("Restoring selinux context for {}...".format(path))
    return subprocess.call(["restorecon", path]) == 0


# Collect the currently active systemd units that clash with any daemon
# named on the command line. The list stays empty when is_systemd_host()
# is False.
try_stop_units = []
if is_systemd_host():
    candidate_units = []
    for arg in args:
        daemon = os.path.basename(arg)
        if is_monolithic_daemon(daemon):
            # libvirtd clashes with itself and with every modular daemon
            candidate_units.extend(daemon_units("libvirtd"))
            for modular in modular_daemons:
                candidate_units.extend(daemon_units(modular))
        elif is_modular_daemon(daemon):
            # A modular daemon clashes only with libvirtd and with itself
            candidate_units.extend(daemon_units("libvirtd"))
            candidate_units.extend(daemon_units(daemon))

    # Only units that are active right now need to be stopped
    try_stop_units.extend(
        unit for unit in candidate_units if is_unit_active(unit))

if not try_stop_units and not opts.selinux:
    # Nothing to set up or clean up afterwards: run the program directly,
    # replacing ourselves
    os.execvpe(prog, args, env)
else:
    stopped_units = []

    # State consulted by the 'finally' clause below. It is initialised
    # before the 'try' so that cleanup never trips over an unbound name,
    # even if a signal arrives right at the start.
    dorestorecon = False
    progpath = None

    # Exit status of this script; stays at 1 (failure) unless the program
    # was actually run, in which case its own status is reported.
    ret = 1

    def sighandler(signum, frame):
        """Turn a termination signal into an exception so cleanup runs."""
        raise OSError("Signal %d received, terminating" % signum)

    signal.signal(signal.SIGHUP, sighandler)
    signal.signal(signal.SIGTERM, sighandler)
    signal.signal(signal.SIGQUIT, sighandler)

    try:
        progpath = shutil.which(prog)
        if not progpath:
            raise Exception("Can't find executable {}"
                            .format(prog))
        progpath = os.path.abspath(progpath)
        if try_stop_units:
            print("Temporarily stopping systemd units...")

            for unit in try_stop_units:
                print(" > %s" % unit)
                if not change_unit(unit, "stop"):
                    raise Exception("Unable to stop '%s'" % unit)

                # Remember only what we really stopped, so that only those
                # units get started again afterwards
                stopped_units.append(unit)

        if opts.selinux:
            progname = os.path.basename(prog)
            # if using a wrapper command like 'gdb', setting the selinux
            # context won't work because the wrapper command will not be a
            # valid entrypoint for the virtd_t context
            if progname not in ["libvirtd", *modular_daemons]:
                raise Exception("'{}' is not recognized as a valid daemon. "
                                "Selinux process context can only be set when "
                                "executing a daemon directly without wrapper "
                                "commands".format(prog))

            # Compare whole path components: a plain string prefix test
            # would also accept a sibling such as '<builddir>-other/...'.
            builddir = os.path.abspath(here)
            if os.path.commonpath([builddir, progpath]) != builddir:
                raise Exception("Refusing to change selinux context of file "
                                "'{}' outside build directory"
                                .format(progpath))

            if progname == "libvirtd":
                context = "virtd"
            else:
                context = progname

            # selinux won't allow us to transition to the virtd_t context from
            # e.g. the user_home_t context (the likely label of the local
            # executable file)
            if not chcon(progpath, "system_u", "object_r", f"{context}_exec_t"):
                raise Exception("Failed to change selinux context of binary")
            dorestorecon = True

            args = ['runcon',
                    '-u', 'system_u',
                    '-r', 'system_r',
                    '-t', f'{context}_t', *args]

        print("Running '%s'..." % str(" ".join(args)))
        ret = subprocess.call(args, env=env)
        if ret < 0:
            # subprocess reports death by signal N as -N; map it to the
            # conventional shell exit status 128+N
            ret = 128 - ret
    except KeyboardInterrupt:
        # Interrupted by the user: still clean up, then report the
        # conventional status for SIGINT
        ret = 128 + signal.SIGINT
    except Exception as e:
        print("%s" % e, file=sys.stderr)
    finally:
        if stopped_units:
            print("Re-starting original systemd units...")
            # Bring units back in the opposite order to how they were stopped
            stopped_units.reverse()
            for unit in stopped_units:
                print(" > %s" % unit)
                if not change_unit(unit, "start"):
                    print(" ! unable to restart %s" % unit, file=sys.stderr)
        if dorestorecon:
            restorecon(progpath)

    # Propagate the outcome to our caller instead of always exiting with 0
    sys.exit(ret)
